{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c34c4d-1563-47e1-aaf0-b33635d7e5c7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import math\n",
    "\n",
    "from unet import UNetModel\n",
    "from attr_classifier import FaceAttrModel\n",
    "from diffusion import GaussianDiffusion\n",
    "\n",
    "import tqdm\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f38458b-8449-42d9-b60b-b1c8afaf0956",
   "metadata": {},
   "source": [
    "## Load Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc76d526-8b35-4943-931c-a0f4249dd219",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained DDPM denoiser (UNet) that supplies the diffusion loss\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "diff_net = UNetModel(\n",
    "    image_size=256, in_channels=3, out_channels=6,\n",
    "    model_channels=256, num_res_blocks=2, channel_mult=(1, 1, 2, 2, 4, 4),\n",
    "    attention_resolutions=[32,16,8], num_head_channels=64, dropout=0.1,\n",
    "    resblock_updown=True, use_scale_shift_norm=True,\n",
    ").to(device)\n",
    "\n",
    "# map_location remaps the checkpoint tensors onto `device`, so loading does not\n",
    "# depend on which device the checkpoint was saved from (same as the attribute\n",
    "# model load in the next cell)\n",
    "diff_net.load_state_dict(torch.load('models/ffhq.pt', map_location=device))\n",
    "print('Loaded Diffusion Model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9598ad9-5cb7-4a58-8ea4-75d0d18b50bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attribute names; later cells index the classifier output column j with\n",
    "# face_attributes[j], so the order must not change\n",
    "face_attributes = [\n",
    "    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',\n",
    "    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',\n",
    "    'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',\n",
    "    'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses',\n",
    "    'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',\n",
    "    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes',\n",
    "    'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose',\n",
    "    'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling',\n",
    "    'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',\n",
    "    'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young',\n",
    "]\n",
    "\n",
    "# Load face attribute model\n",
    "attr_model = FaceAttrModel(pretrained=False, selected_attrs=face_attributes).to(device)\n",
    "# Map the checkpoint onto the same `device` the model was moved to instead of a\n",
    "# second hard-coded device string\n",
    "attr_model.load_state_dict(torch.load('models/Resnet18.pth', map_location=device))\n",
    "print('Loaded attributes model')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dc9763d-ef67-4211-a51f-18606f665211",
   "metadata": {},
   "source": [
    "## Specify Target Attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7235b819-db80-4040-ade8-bf762761e61c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Desired value per attribute. Only read by the MSE loss variant (commented\n",
    "# out in the inference loop) and by the result tables printed further down.\n",
    "target_values = {\n",
    "    '5_o_Clock_Shadow': 0,\n",
    "    'Arched_Eyebrows': 0,\n",
    "    'Attractive': 0,\n",
    "    'Bags_Under_Eyes': 0,\n",
    "    'Bald': 1,\n",
    "    'Bangs': 0,\n",
    "    'Big_Lips': 0,\n",
    "    'Big_Nose': 0,\n",
    "    'Black_Hair': 1,\n",
    "    'Blond_Hair': 1,\n",
    "    'Blurry': 0,\n",
    "    'Brown_Hair': 0,\n",
    "    'Bushy_Eyebrows': 0,\n",
    "    'Chubby': 0,\n",
    "    'Double_Chin': 0,\n",
    "    'Eyeglasses': 1,\n",
    "    'Goatee': 1,\n",
    "    'Gray_Hair': 0,\n",
    "    'Heavy_Makeup': 0,\n",
    "    'High_Cheekbones': 0,\n",
    "    'Male': 1,\n",
    "    'Mouth_Slightly_Open': 0,\n",
    "    'Mustache': 1,\n",
    "    'Narrow_Eyes': 0,\n",
    "    'No_Beard': 0,\n",
    "    'Oval_Face': 0,\n",
    "    'Pale_Skin': 0,\n",
    "    'Pointy_Nose': 0,\n",
    "    'Receding_Hairline': 0,\n",
    "    'Rosy_Cheeks': 1,\n",
    "    'Sideburns': 0,\n",
    "    'Smiling': 1,\n",
    "    'Straight_Hair': 0,\n",
    "    'Wavy_Hair': 1,\n",
    "    'Wearing_Earrings': 0,\n",
    "    'Wearing_Hat': 0,\n",
    "    'Wearing_Lipstick': 0,\n",
    "    'Wearing_Necklace': 0,\n",
    "    'Wearing_Necktie': 0,\n",
    "    # Original note: 0 corresponds to young, 1 corresponds to old.\n",
    "    # NOTE(review): confirm this against the classifier's label convention.\n",
    "    'Young': 1,\n",
    "}\n",
    "\n",
    "# Per-attribute weight, defined once for both loss variants:\n",
    "#   - logit loss: 1 maximizes the attribute logit, -1 minimizes it, 0 ignores it\n",
    "#   - MSE loss:   non-zero entries select which attributes are considered\n",
    "mask_values = dict.fromkeys(target_values, 0)\n",
    "mask_values['Young'] = 1\n",
    "mask_values['Goatee'] = 1\n",
    "\n",
    "# Convert to (1, num_attributes) float tensors ordered like face_attributes\n",
    "n_attrs = len(face_attributes)\n",
    "target_attributes = torch.tensor([target_values[name] for name in face_attributes]).view(1, n_attrs).float().to(device)\n",
    "mask = torch.tensor([mask_values[name] for name in face_attributes]).view(1, n_attrs).float().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ee8ca18-4b71-4bd4-9323-80142df92513",
   "metadata": {},
   "source": [
    "## Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ad4183-d66f-4512-a9ea-7c35964303d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "\n",
    "# Preview of the noisy timestep schedule (same formula as the inference loop)\n",
    "steps = 200\n",
    "t_vals = []\n",
    "for step in range(steps):\n",
    "    remaining = steps - step\n",
    "    # Linearly decreasing + cosine\n",
    "    t = (remaining/1.5 + remaining/3*math.cos(step/10))/steps*800 + 200\n",
    "\n",
    "    # Additional: add integer noise drawn from [-50, 50] to t\n",
    "    jitter = np.random.randint(-50, 51)\n",
    "    t = np.array([t + jitter]).astype(int)\n",
    "    t = np.clip(t, 1, diffusion.T)\n",
    "\n",
    "    t_vals.append(t[0])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8,5))\n",
    "ax.plot(range(steps), t_vals, linewidth=2)\n",
    "ax.set_title('$t$ Annealing Schedule')\n",
    "ax.set_xlabel('Steps')\n",
    "ax.set_ylabel('$t$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "384b9b4d-275f-433e-9694-f0ce7969be9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transforms to apply to attribute classifier input\n",
    "def transform_cls_input(img, prob=0.5, max_brightness=0.2, blur_kernel_size=7, max_sigma=5):\n",
    "    \"\"\"Randomly perturb an image tensor before it is fed to the attribute classifier.\n",
    "\n",
    "    Three perturbations are applied in order, each independently with\n",
    "    probability `prob`: a horizontal flip, a brightness scaling by a factor\n",
    "    drawn uniformly from [1 - max_brightness, 1 + max_brightness), and a\n",
    "    Gaussian blur whose sigma is drawn uniformly from [0, max_sigma).\n",
    "\n",
    "    Args:\n",
    "        img: image tensor accepted by torchvision's functional transforms.\n",
    "        prob: probability of applying each individual perturbation.\n",
    "        max_brightness: half-width of the brightness scaling range.\n",
    "        blur_kernel_size: kernel size handed to `gaussian_blur`.\n",
    "        max_sigma: upper bound of the blur sigma range.\n",
    "\n",
    "    Returns:\n",
    "        The perturbed image, or `img` itself when no perturbation fires.\n",
    "    \"\"\"\n",
    "    tvf = torchvision.transforms.functional\n",
    "\n",
    "    # Horizontal flip\n",
    "    if np.random.rand() < prob:\n",
    "        img = tvf.hflip(img)\n",
    "\n",
    "    # Brightness perturbation\n",
    "    if np.random.rand() < prob:\n",
    "        b = np.random.rand() * (2*max_brightness) - max_brightness\n",
    "        img = img*(1+b)\n",
    "\n",
    "    # Blur\n",
    "    if np.random.rand() < prob:\n",
    "        # rand() can return exactly 0, and gaussian_blur rejects sigma <= 0,\n",
    "        # so keep sigma strictly positive\n",
    "        sigma = max(np.random.rand() * max_sigma, 1e-3)\n",
    "        img = tvf.gaussian_blur(img, blur_kernel_size, sigma)\n",
    "\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7a50ef-99c7-4515-bf8a-b53c42e7d9ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class InferenceModel(nn.Module):\n",
    "    \"\"\"Holds the image being optimized as the module's only learnable parameter.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Inferred image, initialized from standard Gaussian noise\n",
    "        initial_img = torch.randn(1, 3, 256, 256)\n",
    "        self.img = nn.Parameter(initial_img, requires_grad=True)\n",
    "\n",
    "    def encode(self):\n",
    "        \"\"\"Return the image parameter itself (no copy is made).\"\"\"\n",
    "        return self.img\n",
    "model = InferenceModel().to(device)\n",
    "model.train()\n",
    "\n",
    "# Inference procedure steps\n",
    "steps = 200\n",
    "\n",
    "opt = torch.optim.Adamax(model.parameters(), lr=1)\n",
    "# Optional: Linearly decrease learning rate\n",
    "# (start_factor == end_factor == 1 currently keeps the learning rate constant)\n",
    "scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=1, total_iters=steps)\n",
    "\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "diff_net.eval()\n",
    "attr_model.eval()\n",
    "\n",
    "# Loop-invariant settings, defined once instead of on every iteration\n",
    "alpha = 0.5           # smoothing factor of the gradient-norm EMA\n",
    "attr_batch_size = 8   # randomly transformed copies per classifier pass\n",
    "update_every = 10     # progress-bar refresh interval, in steps\n",
    "plot_every = 25       # preview interval, in steps\n",
    "\n",
    "norm_track = 0\n",
    "losses = []\n",
    "bar = tqdm.tqdm(range(steps))\n",
    "for i in bar:\n",
    "    # Select t\n",
    "    t = ((steps-i)/1.5 + (steps-i)/3*math.cos(i/10))/steps*800 + 200 # Linearly decreasing + cosine\n",
    "    t = np.array([t + np.random.randint(-50, 51) for _ in range(1)]).astype(int) # Add noise to t\n",
    "    t = np.clip(t, 1, diffusion.T)\n",
    "\n",
    "    # Denoise\n",
    "    sample_img = model.encode()\n",
    "    xt, epsilon = diffusion.sample(sample_img, t)\n",
    "    t = torch.from_numpy(t).float().view(1)\n",
    "    pred = diff_net(xt.float(), t.to(device))\n",
    "    epsilon_pred = pred[:,:3,:,:] # Use predicted noise only\n",
    "\n",
    "    # Compute diffusion loss\n",
    "    loss = F.mse_loss(epsilon_pred, epsilon)\n",
    "\n",
    "    # Compute EMA of diffusion loss gradient norm\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        grad_norm = torch.linalg.norm(model.img.grad)\n",
    "        if i > 0:\n",
    "            norm_track = alpha*norm_track + (1-alpha)*grad_norm\n",
    "        else:\n",
    "            norm_track = grad_norm\n",
    "\n",
    "    # First update of the step: follow the diffusion loss gradient\n",
    "    opt.step()\n",
    "\n",
    "    # Evaluate attribute classifier on batch of randomly transformed inputs\n",
    "    attr_input_batch = []\n",
    "    for j in range(attr_batch_size):\n",
    "        attr_input = 0.5*(model.encode()+1)\n",
    "        attr_input = transform_cls_input(attr_input, prob=0.5)\n",
    "        attr_input = torch.clip(attr_input, 0, 1)\n",
    "        attr_input = F.interpolate(attr_input, (224,224), mode='nearest')\n",
    "        attr_input_batch.append(attr_input)\n",
    "\n",
    "    attr_input_batch = torch.cat(attr_input_batch, dim=0)\n",
    "    attr = attr_model(attr_input_batch)\n",
    "\n",
    "    # MSE between predicted and target attributes\n",
    "    #loss = torch.sum(F.mse_loss(torch.sigmoid(attr), target_attributes.tile(attr_batch_size,1), reduction='none')*mask) / mask.sum() / attr_batch_size\n",
    "\n",
    "    # Maximize/Minimize attribute logits\n",
    "    loss = -torch.sum(attr*mask.tile(attr_batch_size, 1)) / torch.abs(mask).sum() / attr_batch_size\n",
    "\n",
    "    # Second update of the step: follow the attribute loss gradient\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    # Clip attribute loss gradients to 10% of the tracked diffusion gradient norm\n",
    "    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1*norm_track)\n",
    "    opt.step()\n",
    "    scheduler.step()\n",
    "\n",
    "    # Report the mean attribute loss since the last progress-bar refresh\n",
    "    losses.append(loss.item())\n",
    "    if i % update_every == 0:\n",
    "        bar.set_postfix({'Loss': np.mean(losses)})\n",
    "        losses = []\n",
    "\n",
    "    # Visualize inferred image\n",
    "    if (i+1) % plot_every == 0 or i == 0:\n",
    "        with torch.no_grad():\n",
    "            # The image is unconstrained (it starts as Gaussian noise), so clip\n",
    "            # to [0, 1] here rather than hand imshow out-of-range RGB floats\n",
    "            preview = torch.clip(0.5*(model.encode()+1), 0, 1)[0].cpu().numpy().transpose([1,2,0])\n",
    "        fig, ax = plt.subplots(1, 1, figsize=(10,5))\n",
    "        ax.imshow(preview)\n",
    "        ax.set_title('Inferred Image')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ff86e23-ef32-4d38-975f-96e6623f820f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the inferred image with the attribute classifier\n",
    "with torch.no_grad():\n",
    "    # Rescale with 0.5*(x+1), clip to [0, 1] and resize to 224x224\n",
    "    attr_input = 0.5*(model.encode()+1)\n",
    "    #attr_input = transform_cls_input(attr_input, 0.5)\n",
    "    attr_input = torch.clip(attr_input, 0, 1)\n",
    "    attr_input = F.interpolate(attr_input, (224,224), mode='nearest')\n",
    "\n",
    "    # Sigmoid of the classifier logits\n",
    "    attr = torch.sigmoid(attr_model(attr_input))\n",
    "\n",
    "# Print one row per attribute: prediction, target value and loss mask\n",
    "print(f'{\"Attributes \":20s} | {\"Predicted\":10s} | {\"Target\":10s} | {\"Mask\":10s}')\n",
    "print(f'{\"-\"*53}')\n",
    "for j, attr_name in enumerate(face_attributes):\n",
    "    print(f'{attr_name:20s} | {attr[0,j].item():.2f} {\" \":5s} | {target_attributes[0,j].item():.2f} {\" \":5s} | {mask[0,j].item():.2f}')\n",
    "\n",
    "# Show exactly what the classifier saw (channels-last for imshow)\n",
    "cls_view = attr_input[0].cpu().numpy().transpose([1,2,0])\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10,5))\n",
    "ax.imshow(cls_view)\n",
    "ax.set_title('Classifier Input')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ec594b3-ede2-436b-8b54-29fd41784837",
   "metadata": {},
   "source": [
    "## Fine-tune Image with Denoising Steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c8924d9-8169-41fe-8e53-b82438e3256e",
   "metadata": {},
   "outputs": [],
   "source": [
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "\n",
    "# Start the reverse process at timestep start_t and use as many steps\n",
    "start_t = 200\n",
    "steps = start_t\n",
    "x = model.encode()\n",
    "\n",
    "with torch.no_grad():\n",
    "    diff_net.eval()\n",
    "    fine_tuned = diffusion.inverse(diff_net, shape=(3,256,256), start_t=start_t, steps=steps, x=x, device=device)\n",
    "    diff_net.train()\n",
    "\n",
    "# Clip to [0, 1] for display so imshow is not handed out-of-range RGB floats\n",
    "preview = torch.clip(0.5*(fine_tuned+1), 0, 1)[0].cpu().numpy().transpose([1,2,0])\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5,5))\n",
    "ax.imshow(preview)\n",
    "# Doubled braces keep the literal LaTeX subscript group; a single {0} inside an\n",
    "# f-string is evaluated as the expression 0\n",
    "ax.set_title(f'Fine-tuned Sample | $t_{{0}}$={start_t} steps={steps}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f387afb-b4f5-4da0-bea2-b3f53d5bf660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the fine-tuned image with the attribute classifier\n",
    "with torch.no_grad():\n",
    "    # Rescale with 0.5*(x+1) and resize to 224x224 (no clipping in this cell)\n",
    "    attr_input = 0.5*(fine_tuned+1)\n",
    "    #attr_input = transform_cls_input(attr_input, 0.5)\n",
    "    #attr_input = torch.clip(attr_input, 0, 1)\n",
    "    attr_input = F.interpolate(attr_input, (224,224), mode='nearest')\n",
    "\n",
    "    # Sigmoid of the classifier logits\n",
    "    attr = torch.sigmoid(attr_model(attr_input))\n",
    "\n",
    "# Print one row per attribute: prediction, target value and loss mask\n",
    "print(f'{\"Attributes \":20s} | {\"Predicted\":10s} | {\"Target\":10s} | {\"Mask\":10s}')\n",
    "print(f'{\"-\"*53}')\n",
    "for j, attr_name in enumerate(face_attributes):\n",
    "    print(f'{attr_name:20s} | {attr[0,j].item():.2f} {\" \":5s} | {target_attributes[0,j].item():.2f} {\" \":5s} | {mask[0,j].item():.2f}')\n",
    "\n",
    "# Show exactly what the classifier saw (channels-last for imshow)\n",
    "cls_view = attr_input[0].cpu().numpy().transpose([1,2,0])\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10,5))\n",
    "ax.imshow(cls_view)\n",
    "ax.set_title('Classifier Input')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bb3df51-fc9b-4c93-ae6d-b02ac9cc089e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch_geo_simple",
   "language": "python",
   "name": "conda-env-torch_geo_simple-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
